{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to SPLADE with FastEmbed\n",
    "\n",
    "In this notebook, we will explore how to generate sparse vectors, in particular with a variant of [SPLADE](https://arxiv.org/abs/2107.05720).\n",
    "\n",
    "> 💡 The original [naver/SPLADE](https://github.com/naver/splade) models were licensed under CC BY-NC-SA 4.0, which does not permit commercial use. This [SPLADE++](https://huggingface.co/prithivida/Splade_PP_en_v1) model is released under the Apache License and can therefore be used commercially.\n",
    "\n",
    "## Outline:\n",
    "1. [What is SPLADE?](#What-is-SPLADE?)\n",
    "2. [Setting up the environment](#Setting-up-the-environment)\n",
    "3. [Generating SPLADE vectors with FastEmbed](#Generating-SPLADE-vectors-with-FastEmbed)\n",
    "4. [Understanding SPLADE vectors](#Understanding-SPLADE-vectors)\n",
    "5. [Observations and Model Design Choices](#Observations-and-Model-Design-Choices)\n",
    "\n",
    "\n",
    "## What is SPLADE?\n",
    "\n",
    "SPLADE is a method for _learning_ sparse vectors for text representation. It outperforms BM25, the ranking function behind the Elastic/Lucene family of implementations, which makes it effective for tasks such as information retrieval and document classification.\n",
    "\n",
    "The key advantage of SPLADE is that the vectors it produces are sparse: only a small fraction of the vocabulary receives a non-zero weight, so the vectors are compact to store and index. Each non-zero dimension also corresponds to a vocabulary token, which makes the vectors easier to interpret than dense vectors. This makes SPLADE a practical tool for handling large-scale text data.\n",
    "\n",
    "## Setting up the environment\n",
    "\n",
    "This notebook uses only a few dependencies, which are installed below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# transformers is only needed to decode token ids later in the notebook\n",
    "# !pip install -q fastembed transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating SPLADE vectors with FastEmbed\n",
    "\n",
    "Let's get started! 🚀"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:49:20.516644Z",
     "start_time": "2024-03-30T00:49:20.188543Z"
    }
   },
   "outputs": [],
   "source": [
    "from fastembed import SparseTextEmbedding, SparseEmbedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> You can find the list of all supported Sparse Embedding models by calling this API: `SparseTextEmbedding.list_supported_models()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:49:22.366294Z",
     "start_time": "2024-03-30T00:49:22.362384Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'model': 'prithvida/Splade_PP_en_v1',\n",
       "  'vocab_size': 30522,\n",
       "  'description': 'Misspelled version of the model. Retained for backward compatibility. Independent Implementation of SPLADE++ Model for English',\n",
       "  'size_in_GB': 0.532,\n",
       "  'sources': {'hf': 'Qdrant/SPLADE_PP_en_v1'}},\n",
       " {'model': 'prithivida/Splade_PP_en_v1',\n",
       "  'vocab_size': 30522,\n",
       "  'description': 'Independent Implementation of SPLADE++ Model for English',\n",
       "  'size_in_GB': 0.532,\n",
       "  'sources': {'hf': 'Qdrant/SPLADE_PP_en_v1'}}]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "SparseTextEmbedding.list_supported_models()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:49:27.193530Z",
     "start_time": "2024-03-30T00:49:26.139248Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2aa47b26ab01475e8d3577433037f685",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model_name = \"prithivida/Splade_PP_en_v1\"\n",
    "# This triggers the model download\n",
    "model = SparseTextEmbedding(model_name=model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:49:28.624109Z",
     "start_time": "2024-03-30T00:49:28.399960Z"
    }
   },
   "outputs": [],
   "source": [
    "documents: list[str] = [\n",
    "    \"Chandrayaan-3 is India's third lunar mission\",\n",
    "    \"It aimed to land a rover on the Moon's surface - joining the US, China and Russia\",\n",
    "    \"The mission is a follow-up to Chandrayaan-2, which had partial success\",\n",
    "    \"Chandrayaan-3 will be launched by the Indian Space Research Organisation (ISRO)\",\n",
    "    \"The estimated cost of the mission is around $35 million\",\n",
    "    \"It will carry instruments to study the lunar surface and atmosphere\",\n",
    "    \"Chandrayaan-3 landed on the Moon's surface on 23rd August 2023\",\n",
    "    \"It consists of a lander named Vikram and a rover named Pragyan similar to Chandrayaan-2. Its propulsion module would act like an orbiter.\",\n",
    "    \"The propulsion module carries the lander and rover configuration until the spacecraft is in a 100-kilometre (62 mi) lunar orbit\",\n",
    "    \"The mission used GSLV Mk III rocket for its launch\",\n",
    "    \"Chandrayaan-3 was launched from the Satish Dhawan Space Centre in Sriharikota\",\n",
    "    \"Chandrayaan-3 was launched earlier in the year 2023\",\n",
    "]\n",
    "# embed() returns a generator, so it is wrapped in list() here; batch_size is optional\n",
    "sparse_embeddings_list: list[SparseEmbedding] = list(\n",
    "    model.embed(documents, batch_size=6)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:49:29.646340Z",
     "start_time": "2024-03-30T00:49:29.643411Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SparseEmbedding(values=array([0.05297208, 0.01963477, 0.36459631, 1.38508618, 0.71776593,\n",
       "       0.12667948, 0.46230844, 0.446771  , 0.26897505, 1.01519883,\n",
       "       1.5655334 , 0.29412213, 1.53102326, 0.59785569, 1.1001817 ,\n",
       "       0.02079751, 0.09955651, 0.44249091, 0.09747757, 1.53519952,\n",
       "       1.36765671, 0.15740395, 0.49882549, 0.38629025, 0.76612782,\n",
       "       1.25805044, 0.39058095, 0.27236196, 0.45152301, 0.48262018,\n",
       "       0.26085234, 1.35912788, 0.70710695, 1.71639752]), indices=array([ 1010,  1011,  1016,  1017,  2001,  2018,  2034,  2093,  2117,\n",
       "        2319,  2353,  2509,  2634,  2686,  2796,  2817,  2922,  2959,\n",
       "        3003,  3148,  3260,  3390,  3462,  3523,  3822,  4231,  4316,\n",
       "        4774,  5590,  5871,  6416, 11926, 12076, 16469]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index = 0\n",
    "sparse_embeddings_list[index]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The previous output is the `SparseEmbedding` object for the first document in our list.\n",
    "\n",
    "It contains two arrays, `values` and `indices`:\n",
    "- The `values` array holds the weights of the features (tokens) in the document.\n",
    "- The `indices` array holds the positions of these features in the model's vocabulary.\n",
    "\n",
    "The two arrays are aligned: `values[i]` is the weight of the token whose vocabulary id is `indices[i]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:49:31.549533Z",
     "start_time": "2024-03-30T00:49:31.546398Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token at index 1010 has weight 0.05297207832336426\n",
      "Token at index 1011 has weight 0.01963476650416851\n",
      "Token at index 1016 has weight 0.36459630727767944\n",
      "Token at index 1017 has weight 1.385086178779602\n",
      "Token at index 2001 has weight 0.7177659273147583\n"
     ]
    }
   ],
   "source": [
    "# Print the first 5 token ids and their weights for better understanding.\n",
    "first_embedding = sparse_embeddings_list[0]\n",
    "for token_id, weight in zip(first_embedding.indices[:5], first_embedding.values[:5]):\n",
    "    print(f\"Token at index {token_id} has weight {weight}\")"
   ]
  },
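  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sparse vectors like these are commonly compared with a dot product over the indices they have in common. The cell below is a minimal sketch under that assumption, using made-up toy weights rather than model output. `sparse_dot` is a hypothetical helper for illustration and is not part of FastEmbed.\n",
    "\n",
    "With real embeddings, the same helper could be called with the `indices` and `values` arrays of two `SparseEmbedding` objects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def sparse_dot(indices_a, values_a, indices_b, values_b):\n",
    "    # Hypothetical helper: dot product of two sparse vectors given as (indices, values)\n",
    "    weights_a = dict(zip(indices_a.tolist(), values_a.tolist()))\n",
    "    # Only vocabulary indices present in both vectors contribute to the score\n",
    "    return sum(\n",
    "        weights_a[i] * v\n",
    "        for i, v in zip(indices_b.tolist(), values_b.tolist())\n",
    "        if i in weights_a\n",
    "    )\n",
    "\n",
    "\n",
    "# Toy vectors with made-up weights; the indices stand in for vocabulary ids\n",
    "doc_indices, doc_values = np.array([2001, 2353, 4231]), np.array([0.5, 1.5, 1.0])\n",
    "query_indices, query_values = np.array([2353, 4231, 9999]), np.array([1.0, 2.0, 3.0])\n",
    "\n",
    "# Shared indices are 2353 and 4231, so the score is 1.5 * 1.0 + 1.0 * 2.0\n",
    "print(sparse_dot(doc_indices, doc_values, query_indices, query_values))"
   ]
  },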
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Understanding SPLADE vectors\n",
    "\n",
    "This is still a little abstract, so let's use the tokenizer's vocabulary to map these indices back to tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:49:36.203640Z",
     "start_time": "2024-03-30T00:49:34.889654Z"
    }
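   },
   "outputs": [],
   "source": [
    "import json\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Both entries listed above point to the same Hugging Face repository,\n",
    "# so the first one gives us the matching tokenizer.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    SparseTextEmbedding.list_supported_models()[0][\"sources\"][\"hf\"]\n",
    ")"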
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-30T00:49:36.210049Z",
     "start_time": "2024-03-30T00:49:36.206825Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"chandra\": 1.7163975238800049,\n",
      "    \"third\": 1.5655333995819092,\n",
      "    \"##ya\": 1.535199522972107,\n",
      "    \"india\": 1.5310232639312744,\n",
      "    \"3\": 1.385086178779602,\n",
      "    \"mission\": 1.3676567077636719,\n",
      "    \"lunar\": 1.3591278791427612,\n",
      "    \"moon\": 1.2580504417419434,\n",
      "    \"indian\": 1.1001816987991333,\n",
      "    \"##an\": 1.015198826789856,\n",
      "    \"3rd\": 0.7661278247833252,\n",
      "    \"was\": 0.7177659273147583,\n",
      "    \"spacecraft\": 0.7071069478988647,\n",
      "    \"space\": 0.5978556871414185,\n",
      "    \"flight\": 0.4988254904747009,\n",
      "    \"satellite\": 0.4826201796531677,\n",
      "    \"first\": 0.46230843663215637,\n",
      "    \"expedition\": 0.4515230059623718,\n",
      "    \"three\": 0.4467709958553314,\n",
      "    \"fourth\": 0.44249090552330017,\n",
      "    \"vehicle\": 0.390580952167511,\n",
      "    \"iii\": 0.3862902522087097,\n",
      "    \"2\": 0.36459630727767944,\n",
      "    \"##3\": 0.2941221296787262,\n",
      "    \"planet\": 0.27236196398735046,\n",
      "    \"second\": 0.26897504925727844,\n",
      "    \"missions\": 0.2608523368835449,\n",
      "    \"launched\": 0.15740394592285156,\n",
      "    \"had\": 0.12667948007583618,\n",
      "    \"largest\": 0.09955651313066483,\n",
      "    \"leader\": 0.09747757017612457,\n",
      "    \",\": 0.05297207832336426,\n",
      "    \"study\": 0.02079751156270504,\n",
      "    \"-\": 0.01963476650416851\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "def get_tokens_and_weights(sparse_embedding, tokenizer):\n",
    "    # Decode each vocabulary index into its token string and pair it with its weight\n",
    "    token_weight_dict = {}\n",
    "    for token_id, weight in zip(sparse_embedding.indices, sparse_embedding.values):\n",
    "        token = tokenizer.decode([int(token_id)])\n",
    "        # Cast to a plain float so json.dumps can serialize it regardless of the NumPy dtype\n",
    "        token_weight_dict[token] = float(weight)\n",
    "\n",
    "    # Sort the dictionary by weight, highest first\n",
    "    return dict(\n",
    "        sorted(token_weight_dict.items(), key=lambda item: item[1], reverse=True)\n",
    "    )\n",
    "\n",
    "\n",
    "# Test the function with the first SparseEmbedding\n",
    "print(json.dumps(get_tokens_and_weights(sparse_embeddings_list[index], tokenizer), indent=4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Observations and Model Design Choices\n",
    "\n",
    "1. The relative order of the weights is informative: the most important tokens in the sentence (such as `chandra`, `third` and `india`) have the highest weights.\n",
    "1. **Term Expansion**: The model can assign weights to tokens that do not occur in the document but are related to tokens that do. This lets the sparse vector capture the context of the document. Here, the model has added 'moon' (related to 'lunar'), '3rd' (related to 'third') and 'spacecraft', none of which appear in the original sentence.\n",
    "\n",
    "### Design Choices\n",
    "\n",
    "1. The weights are not normalized: they do not sum to 1 or 100. This is common practice for sparse embeddings, because the magnitude of each weight expresses the importance of that token in the document.\n",
    "1. Every dimension of the sparse vector corresponds to an entry in the model's vocabulary (30,522 entries for this model), so the model cannot assign a weight to a token that is outside that vocabulary.\n",
    "1. Tokens do not map directly to words. The tokenizer splits text into subword pieces (such as `chandra`, `##ya` and `##an` above), so typos and out-of-vocabulary words are still represented instead of being dropped."
   ]
  },
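  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The term expansion described above can be checked by comparing the tokens in the sparse vector with the tokens of the input sentence. The cell below is a minimal sketch that uses a subset of the weights printed earlier, rounded to three decimals. The `document_tokens` list is an assumption written out by hand; calling `tokenizer.tokenize(documents[0])` with the tokenizer loaded earlier should produce it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subset of the token weights printed above for the first document, rounded to 3 decimals\n",
    "token_weights = {\n",
    "    \"chandra\": 1.716, \"third\": 1.566, \"##ya\": 1.535, \"india\": 1.531,\n",
    "    \"3\": 1.385, \"mission\": 1.368, \"lunar\": 1.359, \"moon\": 1.258,\n",
    "    \"indian\": 1.100, \"##an\": 1.015, \"3rd\": 0.766, \"spacecraft\": 0.707,\n",
    "}\n",
    "\n",
    "# Assumption: WordPiece tokens of \"Chandrayaan-3 is India's third lunar mission\",\n",
    "# written out by hand instead of calling the tokenizer\n",
    "document_tokens = [\n",
    "    \"chandra\", \"##ya\", \"##an\", \"-\", \"3\", \"is\",\n",
    "    \"india\", \"'\", \"s\", \"third\", \"lunar\", \"mission\",\n",
    "]\n",
    "\n",
    "# Tokens that occur in the sentence versus tokens the model added through expansion\n",
    "present = {t: w for t, w in token_weights.items() if t in document_tokens}\n",
    "expanded = {t: w for t, w in token_weights.items() if t not in document_tokens}\n",
    "\n",
    "print(\"In the text:\", list(present))\n",
    "print(\"Added by expansion:\", list(expanded))"
   ]
  },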
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fst",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
